import os
import json
from tqdm import tqdm 
from concurrent.futures import ProcessPoolExecutor, as_completed

from utils import clone_repository

# Default locations used when this module is run as a script. All three live
# under the same project root, so the shared prefix is spelled out once.
_PROJECT_ROOT = '/mnt/midnight/steven_zhang/LLM_assisted_compilation'
_BENCHMARK_DATA_DIR = _PROJECT_ROOT + '/Compilation_Benchmark/data'

# Raw repo list read by generate_retrieval_benchmark.
input_file_path = _BENCHMARK_DATA_DIR + '/repo_list_76.json'
# Destination of the generated benchmark JSON.
output_file_path = _BENCHMARK_DATA_DIR + '/retrieval_benchmark.json'
# Directory the benchmark repositories are cloned into.
cloned_repos_dir = _PROJECT_ROOT + '/cloned_repos'

class retrieval_benchmark_base():
    '''
    Base class that builds a retrieval benchmark from a raw repo list and
    accumulates evaluation metrics over per-repo retrieval results.

    Subclasses are expected to override generate_single_retrieval_result and
    evaluate_trajectory; the versions here are placeholders that return an
    empty result and the unmodified prediction respectively.

    Attributes set by __init__:
        retrieval_benchmark: list of per-repo benchmark dicts.
        retrieval_results: list of per-repo retrieval result dicts.
        len: number of repos in the benchmark.
        target_link_accuracy, trajectory_accuracy, trajectory_length,
        trajectory_coverage: running metric totals; final_results divides
        each of them by len.
    '''

    # Number of worker processes used by generate_retrieval_results when
    # multi_processing is enabled. Subclasses may override it.
    MAX_WORKERS = 8

    def __init__(self, input_raw_data_path, output_benchmark_path, cloned_repos_dir, pre_computed_benchmark_file_path=None, output_retrieval_results_file_path = None, pre_computed_retrieval_results_path=None, multi_processing = True):
        '''
        Build (or load) the benchmark, build (or load and evaluate) the
        retrieval results, then print the final averaged metrics.

        Args:
            input_raw_data_path: JSON file with the raw repo list.
            output_benchmark_path: where the generated benchmark JSON is written.
            cloned_repos_dir: directory that holds the cloned repositories.
            pre_computed_benchmark_file_path: if given, the benchmark is loaded
                from this JSON file instead of being generated.
            output_retrieval_results_file_path: where generated retrieval
                results are written; nothing is written when it is None.
            pre_computed_retrieval_results_path: if given, retrieval results
                are loaded from this JSON file and evaluated instead of being
                generated.
            multi_processing: generate results in a process pool when True.
        '''
        self.input_file_path = input_raw_data_path
        self.output_file_path = output_benchmark_path
        self.cloned_repos_dir = cloned_repos_dir
        self.output_retrieval_results_file_path = output_retrieval_results_file_path
        self.multi_processing = multi_processing

        if pre_computed_benchmark_file_path:
            with open(pre_computed_benchmark_file_path, encoding='utf-8') as f:
                self.retrieval_benchmark = json.load(f)
            print("Loaded pre-computed retrieval benchmark.")
        else:
            self.retrieval_benchmark = self.generate_retrieval_benchmark()

        self.len = len(self.retrieval_benchmark)
        print(f'Created retrieval benchmark with {self.len} repos.')

        # Running totals; the evaluate_* methods add to them and
        # final_results turns them into averages.
        self.target_link_accuracy = 0.0
        self.trajectory_accuracy = 0.0
        self.trajectory_length = 0.0
        self.trajectory_coverage = 0.0

        if pre_computed_retrieval_results_path:
            with open(pre_computed_retrieval_results_path, encoding='utf-8') as f:
                self.retrieval_results = json.load(f)
            print("Loaded pre-computed retrieval results.")
            self.evalute_pre_computed_retrieval_results()
        else:
            self.retrieval_results = self.generate_retrieval_results()

        self.final_results()

    def evalute_pre_computed_retrieval_results(self):
        '''
        Evaluate every loaded retrieval result against the benchmark entry
        (or entries) that share its repo_name.

        Results whose repo_name does not appear in the benchmark are skipped.
        The misspelled method name is kept because it is the public name.
        '''
        # Map each repo name to its benchmark positions once, instead of
        # rescanning the benchmark (and calling list.index) per result.
        indices_by_name = {}
        for benchmark_index, benchmark_data in enumerate(self.retrieval_benchmark):
            indices_by_name.setdefault(benchmark_data['repo_name'], []).append(benchmark_index)

        for single_retrieval_result in self.retrieval_results:
            repo_name = single_retrieval_result['repo_name']
            for benchmark_index in indices_by_name.get(repo_name, ()):
                self.evaluate_trajectory(benchmark_index, single_retrieval_result['predicted_trajectory'])
                self.evaluate_target_link(benchmark_index, single_retrieval_result['predicted_target_link'])

    def evaluate_pre_computed_retrieval_results(self):
        '''
        Correctly spelled entry point; delegates to
        evalute_pre_computed_retrieval_results so subclass overrides of the
        original name are still honoured.
        '''
        return self.evalute_pre_computed_retrieval_results()

    def clone_all_repos(self):
        '''
        Clone every benchmark repository into self.cloned_repos_dir.
        '''
        for repo_data in tqdm(self.retrieval_benchmark):
            clone_repository(repo_data['repo_url'], self.cloned_repos_dir)

    def generate_retrieval_benchmark(self):
        '''
        Build the benchmark from the raw repo list at self.input_file_path,
        write it to self.output_file_path as JSON, and return it.

        Returns:
            A list with one dict per repo holding repo_name, repo_url,
            default_branch, retrieval_trajectory, retrieval_target_link
            (always a list), readme_content and repo_dir.
        '''
        with open(self.input_file_path, encoding='utf-8') as f:
            repo_list = json.load(f)

        retrieval_benchmark = []
        for repo_data in repo_list:
            repo_url = repo_data['Github_Url']
            default_branch = repo_data['default_branch']
            readme_suffix = '/blob/' + default_branch + '/' + 'README.md'

            ### Here we assume that the start of the trajectory is always the repo url without .git. However, when treating the benchmark, we use the readme path to check with the predicted target link or trajectory.
            # Work on a copy so the loaded input data is not mutated, and
            # tolerate an empty trajectory instead of raising IndexError.
            ground_truth_trajectory = list(repo_data['retrieval_trajectory'])
            if ground_truth_trajectory:
                ground_truth_trajectory[0] += readme_suffix

            ### Since there may be multiple retrieval target links that could be used, they are separated by "&&" in the data
            # str.split returns a one-element list when "&&" is absent.
            retrieval_target_link = repo_data['retrieval_target_link'].split("&&")

            # A single target equal to the repo url stands for the README,
            # so it gets the same README path as the trajectory start.
            if len(retrieval_target_link) == 1 and retrieval_target_link[0] == repo_url:
                retrieval_target_link[0] += readme_suffix

            retrieval_benchmark.append({
                'repo_name': repo_data['Repo_Name'],
                'repo_url': repo_url,
                'default_branch': default_branch,
                'retrieval_trajectory': ground_truth_trajectory,
                'retrieval_target_link': retrieval_target_link,
                'readme_content': repo_data['readme_content'],
                'repo_dir': os.path.join(self.cloned_repos_dir, repo_data['Repo_Name']),
            })

        with open(self.output_file_path, 'w', encoding='utf-8') as f:
            json.dump(retrieval_benchmark, f, indent=4)

        return retrieval_benchmark

    def get_ground_truth_trajectory(self, index):
        '''
        Return the ground truth trajectory of the repo at the given index.
        '''
        return self.retrieval_benchmark[index]['retrieval_trajectory']

    def get_ground_truth_target_link(self, index):
        '''
        Return the list of accepted target links of the repo at the given index.
        '''
        return self.retrieval_benchmark[index]['retrieval_target_link']

    def get_item(self, index):
        '''
        Return the whole benchmark entry of the repo at the given index.
        '''
        return self.retrieval_benchmark[index]

    def generate_single_retrieval_result(self, index):
        '''
        Given the index of the repo in the retrieval benchmark, generate the retrieval results and then evaluate them

        This base implementation is a placeholder that returns an empty dict.
        '''
        return {}

    def evaluate_trajectory(self, index, predicted_trajectories):
        '''
        Update the evaluation metric of target trajectory based on the retrieval results of the repo at the given index

        This base implementation updates nothing and returns the prediction unchanged.
        '''
        return predicted_trajectories

    def evaluate_target_link(self, index, predicted_target_link):
        '''
        Update the evaluation metric of target link based on the retrieval results of the repo at the given index

        Adds 1 to target_link_accuracy when the prediction is one of the
        accepted ground truth links, and returns the prediction unchanged.
        '''
        if predicted_target_link in self.get_ground_truth_target_link(index):
            self.target_link_accuracy += 1
        return predicted_target_link

    def final_results(self):
        '''
        Returns the final result of the retrieval benchmark: the average of target link accuracy, trajectory accuracy, trajectory length, and trajectory coverage

        Returns:
            A 4-tuple (target link accuracy, trajectory accuracy, trajectory
            length, trajectory coverage), each averaged over self.len. All
            four are 0.0 when the benchmark is empty.
        '''
        totals = (
            self.target_link_accuracy,
            self.trajectory_accuracy,
            self.trajectory_length,
            self.trajectory_coverage,
        )
        if self.len:
            averages = tuple(total / self.len for total in totals)
        else:
            # An empty benchmark has no meaningful average; report zeros
            # instead of raising ZeroDivisionError.
            averages = (0.0, 0.0, 0.0, 0.0)

        print(f'Average Target Link Accuracy: {averages[0]}')
        print(f'Average Trajectory Accuracy: {averages[1]}')
        print(f'Average Trajectory Length: {averages[2]}')
        print(f'Average Trajectory Coverage: {averages[3]}')

        return averages

    def calculate_trajectory_coverage(self, predicted_trajectory, ground_truth_trajectory):
        '''
        Calculate the coverage of the predicted trajectory with respect to the ground truth trajectory

        Returns:
            The fraction of ground truth links that appear in the predicted
            trajectory, between 0.0 and 1.0. Repeated predicted links are
            counted once, and an empty ground truth yields 0.0.
        '''
        if not ground_truth_trajectory:
            return 0.0
        # Count from the ground truth side so duplicates in the prediction
        # cannot push the coverage above 1.
        covered = sum(1 for link in ground_truth_trajectory if link in predicted_trajectory)
        return covered / len(ground_truth_trajectory)

    def generate_retrieval_results(self):
        '''
        Run generate_single_retrieval_result for every benchmark index, save
        the results as JSON when an output path was configured, and return them.

        Returns:
            The list of per-repo results ordered by benchmark index. In
            multi-processing mode, indices whose task raised are reported on
            stdout and left out of the list.
        '''
        if self.multi_processing:
            # NOTE(review): the bound method (and therefore self) is pickled
            # into each worker process, so metric totals updated inside
            # generate_single_retrieval_result do not propagate back to this
            # process - confirm subclasses account for that.
            results_by_index = {}
            with ProcessPoolExecutor(max_workers=self.MAX_WORKERS) as executor:
                # Submit all tasks for indices in parallel.
                futures = {executor.submit(self.generate_single_retrieval_result, i): i for i in range(self.len)}

                # Use tqdm to provide a progress bar.
                for future in tqdm(as_completed(futures), total=self.len):
                    idx = futures[future]
                    try:
                        results_by_index[idx] = future.result()
                    except Exception as exc:
                        print(f'Index {idx} generated an exception: {exc}')
            # Restore benchmark order; completion order is nondeterministic.
            retrieval_results = [results_by_index[idx] for idx in sorted(results_by_index)]
        else:
            # If not using multi-processing, just run the function directly.
            retrieval_results = [self.generate_single_retrieval_result(i) for i in tqdm(range(self.len))]

        if self.output_retrieval_results_file_path is not None:
            with open(self.output_retrieval_results_file_path, 'w', encoding='utf-8') as f:
                json.dump(retrieval_results, f, indent=4)
            print(f"Saved retrieval results to {self.output_retrieval_results_file_path}")
        else:
            # Previously str(None) was used as the path, which silently
            # created a file literally named "None".
            print("No output path given; retrieval results were not saved.")

        print("Generated retrieval results.")
        return retrieval_results
            
            

# def main(input_file_path, output_file_path):
#     if os.path.exists(output_file_path):
#         print(f'{output_file_path} already exists.')
#         with open(output_file_path) as f:
#             retrieval_benchmark = json.load(f)
#     else:
#         retrieval_benchmark = generate_retrieval_benchmark(input_file_path, output_file_path)


if __name__ == '__main__':
    # Script entry point: construct the benchmark object from the module-level
    # default paths (the constructor does the benchmark and result generation),
    # then clone every repository listed in the benchmark.
    benchmark = retrieval_benchmark_base(
        input_raw_data_path=input_file_path,
        output_benchmark_path=output_file_path,
        cloned_repos_dir=cloned_repos_dir,
    )
    benchmark.clone_all_repos()
    
    
    
    

